{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb0e4886-c222-4183-89a7-f077e3b1a105",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Add the project's files to the python path\n",
    "# file_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))  # for .py script\n",
    "file_path = os.path.dirname(os.path.abspath(''))  # for .ipynb notebook\n",
    "sys.path.append(file_path)\n",
    "\n",
    "import hydra\n",
    "from src.utils import init_config, compute_panoptic_metrics, \\\n",
    "    compute_panoptic_metrics_s3dis_6fold, grid_search_panoptic_partition, \\\n",
    "    oracle_superpoint_clustering\n",
    "import torch\n",
    "from src.transforms import *\n",
    "from src.utils.widgets import *\n",
    "from src.data import *\n",
    "\n",
    "# Silence all Python warnings to hide lightning's messages about the\n",
    "# trainer and modules not being connected. This also hides unrelated warnings\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "452559c9-8fbc-4685-bbaf-c6e4ac24cbc7",
   "metadata": {},
   "source": [
    "## Select your device, experiment, split, and pretrained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06a233bd-12c8-415a-a922-222a925733f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "device_widget = make_device_widget()\n",
    "task_widget, expe_widget = make_experiment_widgets()\n",
    "split_widget = make_split_widget()\n",
    "ckpt_widget = make_checkpoint_file_search_widget()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f63dd77-81d4-4245-90f3-0267371dea52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarizing selected task, experiment, split, and checkpoint\n",
    "print(\"You chose:\")\n",
    "print(f\"  - device={device_widget.value}\")\n",
    "print(f\"  - task={task_widget.value}\")\n",
    "print(f\"  - split={split_widget.value}\")\n",
    "print(f\"  - experiment={expe_widget.value}\")\n",
    "print(f\"  - ckpt={ckpt_widget.value}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1353339e-281c-48b9-b5df-c3abf9f1a64e",
   "metadata": {},
   "source": [
    "## Parsing the config files\n",
    "Hydra and OmegaConf are used to parse the `yaml` config files.\n",
    "\n",
    "❗Make sure you selected a **ckpt file relevant to your experiment** in the previous section. \n",
    "You can use our pretrained models for this, or your own checkpoints if you have already trained a model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ebc1dbb-3c8e-451a-ac24-3cb8df0c28a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse the configs using hydra\n",
    "cfg = init_config(overrides=[\n",
    "    f\"experiment={task_widget.value}/{expe_widget.value}\",\n",
    "    f\"ckpt_path={ckpt_widget.value}\"\n",
    "])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47121b17-c514-40e9-b441-a6d695cce8c1",
   "metadata": {},
   "source": [
    "## Datamodule and model instantiation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdd9594b-6d5c-43f7-99e4-e37ae7e985c9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Instantiate the datamodule\n",
    "datamodule = hydra.utils.instantiate(cfg.datamodule)\n",
    "datamodule.prepare_data()\n",
    "datamodule.setup()\n",
    "\n",
    "# Pick among the train, val, and test datasets. Note that the train\n",
    "# dataset produces augmented spherical samples of large scenes, while\n",
    "# the val and test datasets load entire tiles at once\n",
    "if split_widget.value == 'train':\n",
    "    dataset = datamodule.train_dataset\n",
    "elif split_widget.value == 'val':\n",
    "    dataset = datamodule.val_dataset\n",
    "elif split_widget.value == 'test':\n",
    "    dataset = datamodule.test_dataset\n",
    "else:\n",
    "    raise ValueError(f\"Unknown split '{split_widget.value}'\")\n",
    "\n",
    "# Print a summary of the dataset's classes\n",
    "dataset.print_classes()\n",
    "\n",
    "# Instantiate the model\n",
    "model = hydra.utils.instantiate(cfg.model)\n",
    "\n",
    "# Load pretrained weights from a checkpoint file\n",
    "if ckpt_widget.value is not None:\n",
    "    model = model._load_from_checkpoint(cfg.ckpt_path)\n",
    "\n",
    "# Move model to selected device\n",
    "model = model.eval().to(device_widget.value)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9546ad1-f819-44d9-b2b7-343e96b72efb",
   "metadata": {},
   "source": [
    "## Oracles on a tile sample\n",
    "We provide oracles for estimating the maximum achievable performance of our superpoint-graph-clustering approach on a point cloud. Note that these metrics are computed on a single tile, not on the entire dataset. Each oracle is computed at a given superpoint partition level and is bounded by the quality of that partition:\n",
    "\n",
    "- `semantic_segmentation_oracle`: assign to each superpoint the most frequent label among the points it contains\n",
    "- `panoptic_segmentation_oracle`: same as for semantic segmentation + assign each superpoint to the target instance it overlaps the most\n",
    "- `oracle_superpoint_clustering`: same as for semantic segmentation + assign to each edge the target affinity + compute the graph clustering to form instance predictions\n",
    "\n",
    "These oracles depend on how the superpoint partition has been computed. In addition, `oracle_superpoint_clustering` depends on the graph clustering parameters."
   ]
  },
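  {
   "cell_type": "markdown",
   "id": "7d2f4c1a-3b6e-4f80-9a15-c2e8d0b1f6a3",
   "metadata": {},
   "source": [
    "As a rough illustration of the idea behind `semantic_segmentation_oracle`, the sketch below gives every point the majority label of its superpoint and measures the resulting accuracy. It is a minimal pure-Python sketch on made-up data, not the implementation used in this project: `semantic_oracle`, `superpoint_ids`, and `labels` are hypothetical names, and the actual method may report different metrics.\n",
    "\n",
    "```python\n",
    "from collections import Counter, defaultdict\n",
    "\n",
    "\n",
    "def semantic_oracle(superpoint_ids, labels):\n",
    "    # Count the labels of the points falling in each superpoint\n",
    "    votes = defaultdict(Counter)\n",
    "    for sp, y in zip(superpoint_ids, labels):\n",
    "        votes[sp][y] += 1\n",
    "    # Keep the most frequent label of each superpoint\n",
    "    majority = {sp: c.most_common(1)[0][0] for sp, c in votes.items()}\n",
    "    # Broadcast the superpoint label back to its points\n",
    "    return [majority[sp] for sp in superpoint_ids]\n",
    "\n",
    "\n",
    "# Toy tile (hypothetical data): 8 points grouped into 3 superpoints\n",
    "superpoint_ids = [0, 0, 0, 1, 1, 2, 2, 2]\n",
    "labels = [4, 4, 7, 7, 7, 1, 1, 1]\n",
    "\n",
    "oracle = semantic_oracle(superpoint_ids, labels)\n",
    "accuracy = sum(p == y for p, y in zip(oracle, labels)) / len(labels)\n",
    "print(oracle)    # [4, 4, 4, 7, 7, 1, 1, 1]\n",
    "print(accuracy)  # 0.875\n",
    "```\n",
    "\n",
    "Superpoint 0 mixes two classes in this toy tile, so no per-superpoint labeling can exceed 87.5% accuracy. This is the kind of upper bound an oracle is meant to expose."
   ]
  },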
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21614c71-e332-439e-8fe7-bd51825e4e0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the panoptic annotations for a tile from the dataset \n",
    "obj = dataset[0][1].obj"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "010bba12-6161-45ee-9af3-64183f51588b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the semantic segmentation oracle\n",
    "obj.semantic_segmentation_oracle(dataset.num_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c1a4cba-bac7-485d-a72d-5afc9aea3958",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the panoptic segmentation oracle without graph clustering\n",
    "obj.panoptic_segmentation_oracle(dataset.num_classes, stuff_classes=dataset.stuff_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a43b3ccb-bd92-4c30-8a35-8ef4037bd0de",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the panoptic segmentation oracle with graph clustering\n",
    "oracle_superpoint_clustering(\n",
    "    dataset[0],\n",
    "    dataset.num_classes,\n",
    "    dataset.stuff_classes,\n",
    "    mode='pas',\n",
    "    graph_kwargs=dict(\n",
    "        radius=0.1),\n",
    "    partition_kwargs=dict(\n",
    "        regularization=0.1,\n",
    "        x_weight=1e-3,\n",
    "        cutoff=300))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47e4f497-91b8-4ad0-88dc-aaecf6b0f5e0",
   "metadata": {},
   "source": [
    "## Grid-searching partition parameters on a tile sample\n",
    "Our SuperCluster model is trained to predict the input for a graph clustering problem whose solution is a panoptic segmentation of the scene.\n",
    "Interestingly, with our formulation, the model is **only supervised with local node-wise and edge-wise objectives, without ever needing to compute an actual panoptic partition of the scene during training**.\n",
    "\n",
    "At inference time, however, we need to decide on some parameters for our graph clustering algorithm.\n",
    "To this end, a simple post-training grid-search can be used.\n",
    "\n",
    "We find that similar parameters maximize panoptic segmentation results on all our datasets.\n",
    "Here, we provide utilities to help you grid-search parameters yourself. See the `grid_search_panoptic_partition` docstring for more details on how to use this tool."
   ]
  },
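  {
   "cell_type": "markdown",
   "id": "1e9b7a54-6c3d-42f1-8b07-5d4a9f2c7e10",
   "metadata": {},
   "source": [
    "To get a sense of the cost of such a search, the generic sketch below enumerates the combinations implied by a grid of candidate values. It is not the implementation of `grid_search_panoptic_partition`: `expand_grid` is a hypothetical helper, and we assume that list-valued entries are searched while scalar entries are held fixed.\n",
    "\n",
    "```python\n",
    "from itertools import product\n",
    "\n",
    "# Candidate values: lists are searched, scalars are kept fixed (assumption)\n",
    "partition_grid = dict(\n",
    "    regularization=[2e1, 1e1, 5],\n",
    "    x_weight=[5e-2, 1e-2, 1e-3, 1e-4],\n",
    "    cutoff=300)\n",
    "\n",
    "\n",
    "def expand_grid(grid):\n",
    "    # Yield one kwargs dict per combination of candidate values\n",
    "    keys = list(grid)\n",
    "    values = [v if isinstance(v, list) else [v] for v in grid.values()]\n",
    "    for combo in product(*values):\n",
    "        yield dict(zip(keys, combo))\n",
    "\n",
    "\n",
    "candidates = list(expand_grid(partition_grid))\n",
    "print(len(candidates))  # 12\n",
    "print(candidates[0])    # {'regularization': 20.0, 'x_weight': 0.05, 'cutoff': 300}\n",
    "```\n",
    "\n",
    "The number of graph clustering runs grows multiplicatively with each searched parameter, so it is worth keeping the candidate lists short."
   ]
  },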
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16ddc046-dd0a-4366-8169-0e96b707e3da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grid search graph clustering parameters\n",
    "output, partitions, results = grid_search_panoptic_partition(\n",
    "    model,\n",
    "    datamodule.val_dataset,\n",
    "    i_cloud=0,\n",
    "    graph_kwargs=dict(\n",
    "        radius=0.1),\n",
    "    partition_kwargs=dict(\n",
    "        regularization=[2e1, 1e1, 5],\n",
    "        x_weight=[5e-2, 1e-2, 1e-3, 1e-4],\n",
    "        cutoff=300),\n",
    "    mode='pas')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2200c0b8-9af0-40d3-828a-26e76bd2c2e4",
   "metadata": {},
   "source": [
    "## Running evaluation on a whole dataset\n",
    "The above grid search only computes the panoptic segmentation metrics on a single point cloud.\n",
    "In this section, we provide tools for computing the panoptic metrics on a whole dataset. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb38437f-90cd-47e1-bd11-8a3eba25dfec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the panoptic, instance, and semantic metrics on the whole 'val' stage\n",
    "panoptic, instance, semantic = compute_panoptic_metrics(\n",
    "    model,\n",
    "    datamodule,\n",
    "    stage='val',\n",
    "    graph_kwargs=dict(\n",
    "        radius=0.1),\n",
    "    partition_kwargs=dict(\n",
    "        regularization=1e1,\n",
    "        x_weight=5e-2,\n",
    "        cutoff=300))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b789ae6f-771d-4500-a7e0-6cd9e391bb45",
   "metadata": {},
   "source": [
    "### S3DIS 6-fold metrics\n",
    "The following utility computes the S3DIS 6-fold metrics from one checkpoint per fold."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e3e09a-cae1-41b1-b4bc-9fb5834368a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoint path for each S3DIS fold; replace the placeholders with your own files\n",
    "fold_ckpt = {\n",
    "    1: \"/path/to/your/s3dis/checkpoint/fold_1.ckpt\",\n",
    "    2: \"/path/to/your/s3dis/checkpoint/fold_2.ckpt\",\n",
    "    3: \"/path/to/your/s3dis/checkpoint/fold_3.ckpt\",\n",
    "    4: \"/path/to/your/s3dis/checkpoint/fold_4.ckpt\",\n",
    "    5: \"/path/to/your/s3dis/checkpoint/fold_5.ckpt\",\n",
    "    6: \"/path/to/your/s3dis/checkpoint/fold_6.ckpt\",\n",
    "}\n",
    "\n",
    "experiment_config = f\"experiment={task_widget.value}/{expe_widget.value}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "903673b0-4ad0-4bca-b713-b004a7d03309",
   "metadata": {},
   "outputs": [],
   "source": [
    "_ = compute_panoptic_metrics_s3dis_6fold(\n",
    "    fold_ckpt,\n",
    "    experiment_config,\n",
    "    stage='val', \n",
    "    graph_kwargs=dict(\n",
    "        radius=0.1),\n",
    "    partition_kwargs=dict(\n",
    "        regularization=10,\n",
    "        x_weight=1e-3,\n",
    "        cutoff=300),\n",
    "    verbose=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:spt] *",
   "language": "python",
   "name": "conda-env-spt-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
